{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "02623ad8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#全局常量\n",
    "repo_id = 'lansinuote/diffusion.3.dream_booth'\n",
    "checkpoint = 'runwayml/stable-diffusion-v1-5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "315ad307",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1\n",
      "Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.3.dream_booth-4344a7d7501be7f1/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(Dataset({\n",
       "     features: ['image', 'text'],\n",
       "     num_rows: 5\n",
       " }),\n",
       " {'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=2469x2558>,\n",
       "  'text': 'a photo of little dog'})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "#直接使用我处理好的数据集\n",
    "dataset = load_dataset(path=repo_id, split='train')\n",
    "\n",
    "dataset, dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2af94061",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pixel_values torch.Size([3, 512, 512]) torch.float32\n",
      "input_ids torch.Size([77]) torch.int64\n",
      "attention_mask torch.Size([77]) torch.int64\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['image', 'text'],\n",
       "    num_rows: 5\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torchvision\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "#数据增强\n",
    "compose = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(\n",
    "        512, interpolation=torchvision.transforms.InterpolationMode.BILINEAR),\n",
    "    torchvision.transforms.RandomCrop(512),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize([0.5], [0.5]),\n",
    "])\n",
    "\n",
    "#文字编码\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint,\n",
    "                                          subfolder='tokenizer',\n",
    "                                          use_fast=False)\n",
    "\n",
    "\n",
    "def f(data):\n",
    "    #图像编码\n",
    "    pixel_values = compose(data['image'][0]).unsqueeze(dim=0)\n",
    "\n",
    "    #文字编码\n",
    "    #77 = tokenizer.model_max_length\n",
    "    tokens = tokenizer(data['text'][0],\n",
    "                       truncation=True,\n",
    "                       padding='max_length',\n",
    "                       max_length=77,\n",
    "                       return_tensors='pt')\n",
    "\n",
    "    return {\n",
    "        'pixel_values': pixel_values,\n",
    "        'input_ids': tokens.input_ids,\n",
    "        'attention_mask': tokens.attention_mask\n",
    "    }\n",
    "\n",
    "\n",
    "dataset = dataset.with_transform(f)\n",
    "\n",
    "for k, v in dataset[0].items():\n",
    "    print(k, v.shape, v.dtype)\n",
    "\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c07c2a12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pixel_values torch.Size([1, 3, 512, 512]) torch.float32\n",
      "input_ids torch.Size([1, 77]) torch.int64\n",
      "attention_mask torch.Size([1, 77]) torch.int64\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "loader = torch.utils.data.DataLoader(dataset,\n",
    "                                     batch_size=1,\n",
    "                                     shuffle=True,\n",
    "                                     collate_fn=None)\n",
    "\n",
    "for k, v in next(iter(loader)).items():\n",
    "    print(k, v.shape, v.dtype)\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e0459d36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder 12306.048\n",
      "vae 8365.3863\n",
      "unet 85952.0964\n"
     ]
    }
   ],
   "source": [
    "from transformers.models.clip.modeling_clip import CLIPTextModel\n",
    "from diffusers import AutoencoderKL, UNet2DConditionModel\n",
    "\n",
    "encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')\n",
    "vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')\n",
    "unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')\n",
    "\n",
    "vae.requires_grad_(False)\n",
    "encoder.requires_grad_(False)\n",
    "unet.requires_grad_(False)\n",
    "\n",
    "\n",
    "def print_model_size(name, model):\n",
    "    print(name, sum(i.numel() for i in model.parameters()) / 10000)\n",
    "\n",
    "\n",
    "print_model_size('encoder', encoder)\n",
    "print_model_size('vae', vae)\n",
    "print_model_size('unet', unet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c13480bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor 320 None\n",
      "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor 320 768\n",
      "down_blocks.0.attentions.1.transformer_blocks.0.attn1.processor 320 None\n",
      "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor 320 768\n",
      "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor 640 None\n",
      "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor 640 768\n",
      "down_blocks.1.attentions.1.transformer_blocks.0.attn1.processor 640 None\n",
      "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor 640 768\n",
      "down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor 1280 None\n",
      "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor 1280 768\n",
      "down_blocks.2.attentions.1.transformer_blocks.0.attn1.processor 1280 None\n",
      "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor 1280 768\n",
      "up_blocks.1.attentions.0.transformer_blocks.0.attn1.processor 1280 None\n",
      "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor 1280 768\n",
      "up_blocks.1.attentions.1.transformer_blocks.0.attn1.processor 1280 None\n",
      "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor 1280 768\n",
      "up_blocks.1.attentions.2.transformer_blocks.0.attn1.processor 1280 None\n",
      "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor 1280 768\n",
      "up_blocks.2.attentions.0.transformer_blocks.0.attn1.processor 640 None\n",
      "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor 640 768\n",
      "up_blocks.2.attentions.1.transformer_blocks.0.attn1.processor 640 None\n",
      "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor 640 768\n",
      "up_blocks.2.attentions.2.transformer_blocks.0.attn1.processor 640 None\n",
      "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor 640 768\n",
      "up_blocks.3.attentions.0.transformer_blocks.0.attn1.processor 320 None\n",
      "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor 320 768\n",
      "up_blocks.3.attentions.1.transformer_blocks.0.attn1.processor 320 None\n",
      "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor 320 768\n",
      "up_blocks.3.attentions.2.transformer_blocks.0.attn1.processor 320 None\n",
      "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor 320 768\n",
      "mid_block.attentions.0.transformer_blocks.0.attn1.processor 1280 None\n",
      "mid_block.attentions.0.transformer_blocks.0.attn2.processor 1280 768\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/diffusers/models/cross_attention.py:30: FutureWarning: Importing from cross_attention is deprecated. Please import from diffusers.models.attention_processor instead.\n",
      "  deprecate(\n",
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/diffusers/models/cross_attention.py:58: FutureWarning: LoRACrossAttnProcessor is deprecated and will be removed in `0.18.0`. Please use `from diffusers.models.attention_processor import LoRAAttnProcessor instead.\n",
      "  deprecate(\"cross_attention\", \"0.18.0\", deprecation_message, standard_warn=False)\n"
     ]
    }
   ],
   "source": [
    "#这个类不想测了,总之就是在做注意力的调整\n",
    "from diffusers.models.cross_attention import LoRACrossAttnProcessor\n",
    "\n",
    "\n",
    "def set_processors():\n",
    "    processors = {}\n",
    "\n",
    "    #遍历unet的所有层,找出所有有set_processor属性的层,每一个都组装成lora层\n",
    "    for name in unet.attn_processors.keys():\n",
    "        #768 = unet.config.cross_attention_dim\n",
    "        cross_attention_dim = 768\n",
    "        if name.endswith('attn1.processor'):\n",
    "            cross_attention_dim = None\n",
    "\n",
    "        #1280 = unet.config.block_out_channels[-1]\n",
    "        hidden_size = 1280\n",
    "\n",
    "        if name.startswith('up_blocks'):\n",
    "            #取层名字中的第一个数字\n",
    "            #p_blocks.1.attentions.0.transformer_blocks.0.attn1.processor -> 1\n",
    "            block_id = int(name[10])\n",
    "            hidden_size = list(reversed(\n",
    "                unet.config.block_out_channels))[block_id]\n",
    "\n",
    "        if name.startswith('down_blocks'):\n",
    "            #取层名字中的第一个数字\n",
    "            #down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor -> 2\n",
    "            block_id = int(name[12])\n",
    "            hidden_size = unet.config.block_out_channels[block_id]\n",
    "\n",
    "        processors[name] = LoRACrossAttnProcessor(hidden_size,\n",
    "                                                  cross_attention_dim)\n",
    "\n",
    "        print(name, hidden_size, cross_attention_dim)\n",
    "\n",
    "    #把上面组装好的字典,设置到unet的层当中\n",
    "    unet.set_attn_processor(processors)\n",
    "\n",
    "\n",
    "set_processors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2690fbdf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lora_layers = torch.nn.ModuleList(unet.attn_processors.values())\n",
    "\n",
    "len(lora_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c6f4b455",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DDPMScheduler {\n",
       "   \"_class_name\": \"DDPMScheduler\",\n",
       "   \"_diffusers_version\": \"0.15.1\",\n",
       "   \"beta_end\": 0.012,\n",
       "   \"beta_schedule\": \"scaled_linear\",\n",
       "   \"beta_start\": 0.00085,\n",
       "   \"clip_sample\": false,\n",
       "   \"clip_sample_range\": 1.0,\n",
       "   \"dynamic_thresholding_ratio\": 0.995,\n",
       "   \"num_train_timesteps\": 1000,\n",
       "   \"prediction_type\": \"epsilon\",\n",
       "   \"sample_max_value\": 1.0,\n",
       "   \"set_alpha_to_one\": false,\n",
       "   \"skip_prk_steps\": true,\n",
       "   \"steps_offset\": 1,\n",
       "   \"thresholding\": false,\n",
       "   \"trained_betas\": null,\n",
       "   \"variance_type\": \"fixed_small\"\n",
       " },\n",
       " AdamW (\n",
       " Parameter Group 0\n",
       "     amsgrad: False\n",
       "     betas: (0.9, 0.999)\n",
       "     capturable: False\n",
       "     eps: 1e-08\n",
       "     foreach: None\n",
       "     lr: 0.0001\n",
       "     maximize: False\n",
       "     weight_decay: 0.01\n",
       " ),\n",
       " MSELoss())"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import DDPMScheduler\n",
    "\n",
    "scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')\n",
    "\n",
    "optimizer = torch.optim.AdamW(lora_layers.parameters(),\n",
    "                              lr=1e-4,\n",
    "                              betas=(0.9, 0.999),\n",
    "                              weight_decay=0.01,\n",
    "                              eps=1e-8)\n",
    "\n",
    "criterion = torch.nn.MSELoss()\n",
    "\n",
    "scheduler, optimizer, criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b3effea6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.2743, grad_fn=<MseLossBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss(data):\n",
    "    device = data['input_ids'].device\n",
    "\n",
    "    #编码文字\n",
    "    #[1, 77] -> [1, 77, 768]\n",
    "    out_encoder = encoder(input_ids=data['input_ids'],\n",
    "                          attention_mask=data['attention_mask'])[0]\n",
    "\n",
    "    #vae计算特征图\n",
    "    #[1, 3, 512, 512] -> [1, 4, 64, 64]\n",
    "    out_vae = vae.encode(data['pixel_values']).latent_dist.sample().detach()\n",
    "\n",
    "    #0.18215 = vae.config.scaling_factor\n",
    "    out_vae = out_vae * 0.18215\n",
    "\n",
    "    #随机噪声\n",
    "    #[1, 4, 64, 64]\n",
    "    noise = torch.randn_like(out_vae)\n",
    "\n",
    "    #随机噪声步\n",
    "    #1000 = scheduler.config.num_train_timesteps\n",
    "    #1 = b\n",
    "    noise_step = torch.randint(0, 1000, (1, ), device=device).long()\n",
    "\n",
    "    #添加噪声\n",
    "    #[1, 4, 64, 64]\n",
    "    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)\n",
    "\n",
    "    #unet从噪声图中把噪声计算出来\n",
    "    #[1, 4, 64, 64],[1, 77, 768] -> [1, 4, 64, 64]\n",
    "    out_unet = unet(out_vae_noise, noise_step, out_encoder).sample\n",
    "\n",
    "    return criterion(out_unet, noise)\n",
    "\n",
    "\n",
    "get_loss({\n",
    "    'pixel_values': torch.randn(1, 3, 512, 512),\n",
    "    'input_ids': torch.ones(1, 77).long(),\n",
    "    'attention_mask': torch.ones(1, 77).long()\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "13a8c245",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.47430265601724386\n",
      "10 6.008332805009559\n",
      "20 5.410161764128134\n",
      "30 4.312515129800886\n",
      "40 5.0653923387872055\n",
      "50 4.800978829618543\n",
      "60 4.101198305841535\n",
      "70 3.55872066388838\n",
      "80 5.382800737861544\n",
      "90 3.1098747176583856\n",
      "100 5.2366352655226365\n",
      "110 3.251381019828841\n",
      "120 6.391375361359678\n",
      "130 3.4201086544198915\n",
      "140 3.95328053063713\n",
      "150 4.018299808958545\n",
      "160 4.696397997322492\n",
      "170 3.5135789506603032\n",
      "180 4.713769806083292\n",
      "190 3.8237785125384107\n"
     ]
    }
   ],
   "source": [
    "from diffusers import DiffusionPipeline\n",
    "\n",
    "\n",
    "def train():\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    unet.to(device)\n",
    "    vae.to(device)\n",
    "    encoder.to(device)\n",
    "    unet.train()\n",
    "\n",
    "    loss_sum = 0\n",
    "    for epoch in range(200):\n",
    "        for i, data in enumerate(loader):\n",
    "            for k in data.keys():\n",
    "                data[k] = data[k].to(device)\n",
    "            loss = get_loss(data)\n",
    "\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(lora_layers.parameters(), 1.0)\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            loss_sum += loss.item()\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            print(epoch, loss_sum)\n",
    "            loss_sum = 0\n",
    "\n",
    "    unet.save_attn_procs('./save')\n",
    "\n",
    "\n",
    "train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
